Set Implicit Arguments.

Require Import FCF.FCF.
Require Import FCF.CompFold.
Require Import FCF.PRF.
Require Import hmacdrbg.PRG_NA.

Local Open Scope list_scope.
Section HMAC_DRBG_PRG_NA.

(* HMAC-DRBG spec *)

(* The security parameter eta determines the size of PRF keys; since f
   returns a Bvector eta, it is also the size of every output block. *)
Variable eta : nat.
(* eta is assumed nonzero.  No definition visible in this section uses
   this hypothesis directly -- presumably it is consumed by the imported
   proof (TODO confirm). *)
Variable eta_nz : eta <> O.
(* The function f models HMAC: given an eta-bit key and a bit-list
   message, it returns an eta-bit block. *)
Variable f : Bvector eta -> Blist -> Bvector eta.

(* Distributions of the initial key and initial value: each is drawn from
   {0,1}^eta (FCF's random eta-bit vector). *)
Definition RndK : Comp (Bvector eta) := {0,1}^eta.
Definition RndV : Comp (Bvector eta) := {0,1}^eta.
(* The DRBG internal state: a (key, value) pair, as built by Instantiate. *)
Definition KV : Set := (Bvector eta * Bvector eta)%type.

(* Instantiate: sample the initial DRBG state.
   The key is drawn from RndK and the value from RndV, independently, and
   the pair (key, value) is returned.
   NOTE: this is a simplified model and does not reflect the NIST spec's
   instantiation procedure (it takes no seed material of any kind). *)
Definition Instantiate : Comp KV :=
  k0 <-$ RndK;
  v0 <-$ RndV;
  ret (k0, v0).

(* The Generate function *)
(* to_list: forget the length index of a vector, giving a plain list with
   the same elements in the same order.  A thin wrapper around
   Vector.to_list; under Set Implicit Arguments, A and n are inferred from
   the vector argument. *)
Definition to_list (A : Type) (n : nat) (vec : Vector.t A n) : list A :=
  Vector.to_list vec.

(* Gen_loop k v n: the inner output loop of Generate.
   Starting from v, it applies v := f k v exactly n times.  It returns the
   list of every value produced, in order, together with the last one.
   For n = O the result is (nil, v). *)
Fixpoint Gen_loop (k : Bvector eta) (v : Bvector eta) (n : nat)
  : list (Bvector eta) * Bvector eta :=
  match n with
  | O => (nil, v)
  | S m =>
    let next := f k (to_list v) in
    let (rest, vFinal) := Gen_loop k next m in
    (next :: rest, vFinal)
  end.

(* replicate n a: the list consisting of n copies of a (nil when n = O). *)
Fixpoint replicate {A : Type} (n : nat) (a : A) : list A :=
  match n with
  | O => nil
  | S k => a :: replicate k a
  end.

(* The spec appends one zero byte ("V || 0x00") when refreshing the key;
   here that byte is modelled as a list of 8 false bits. *)
Definition zeroes : list bool := replicate 8 false.

(* Generate state n: produce n output blocks from state (k, v) with
   Gen_loop, then refresh the state using the last loop value V:
     new key   := f k (V ++ zeroes)      (the spec's "V || 0x00")
     new value := f (new key) V
   Returns the blocks together with the refreshed (key, value) pair. *)
Definition Generate (state : KV) (n : nat) :
  Comp (list (Bvector eta) * KV) :=
  [k0, v0] <-2 state;
  [blocks, vLast] <-2 Gen_loop k0 v0 n;
  k1 <- f k0 (to_list vLast ++ zeroes);
  v1 <- f k1 (to_list vLast);
  ret (blocks, (k1, v1)).


(* The adversary against the PRG *)
Variable blocksPerCall : nat.       (* blocks generated by Gen_loop in each call to Generate *)
Variable blocksPerCall_gt_0 : blocksPerCall > O.
Variable numCalls : nat.        (* number of calls to Generate *)
Variable numCalls_gt_0: numCalls > O.
(* The fixed (non-adaptive) request schedule: numCalls requests, each for
   blocksPerCall blocks. *)
Definition requestList : list nat := replicate numCalls blocksPerCall.
(* The PRG distinguisher: it receives one list of blocks per Generate call
   and outputs a guess bit. *)
Variable A : list (list (Bvector eta)) -> Comp bool.
(* A is assumed well-formed (in FCF's well_formed_comp sense) on every input. *)
Variable A_wf : forall ls, well_formed_comp (A ls).

(* The constructed adversary against the PRF *)
(* oracleCompMap_inner e1 e2 oracleComp state inputs:
   runs oracleComp on each element of inputs from left to right, threading
   the explicit state (a counter paired with a KV) from one call to the
   next.  Returns the per-input results in input order together with the
   final state; on nil it returns (nil, state) unchanged.
   e1 and e2 are never referenced by name in the body (the recursive call
   passes _ _).  NOTE(review): presumably they are found from the local
   context by typeclass resolution for the ret/bind notations -- confirm
   before removing them. *)
Fixpoint oracleCompMap_inner {D R OracleIn OracleOut : Set} 
           (e1 : EqDec ((list R) * (nat * KV))) 
           (e2 : EqDec (list R))
           (* this is an oracleComp, not an oracle: a state-passing step
              that may itself query an oracle taking OracleIn to OracleOut;
              it maps (state, input) to (result, new state) *)
           (oracleComp : (nat * KV) -> D -> OracleComp OracleIn OracleOut (R * (nat * KV))) 
           (state : (nat * KV)) (* note this state type -- it is EXPLICITLY being passed around *)
           (inputs : list D) : OracleComp OracleIn OracleOut (list R * (nat * KV)) :=
  match inputs with
  | nil => $ ret (nil, state)
  | input :: inputs' => 
    [res, state'] <--$2 oracleComp state input;
    [resList, state''] <--$2 oracleCompMap_inner _ _ oracleComp state' inputs';
    $ ret (res :: resList, state'')
  end.

(* oracleCompMap_outer e1 e2 oracleComp inputs:
   samples a fresh state (k, v) with Instantiate and then runs
   oracleCompMap_inner over inputs, starting from counter O.  Only the
   list of results is returned; the final state is discarded. *)
Definition oracleCompMap_outer {D R OracleIn OracleOut : Set} 
           (e1 : EqDec ((list R) * (nat * KV))) 
           (e2 : EqDec (list R))
           (oracleComp : (nat * KV) -> D -> OracleComp OracleIn OracleOut (R * (nat * KV)))
           (inputs : list D) : OracleComp OracleIn OracleOut (list R) :=
  [k, v] <--$2 $ Instantiate;   (* generate state inside, instead of being passed state *)
  [bits, _] <--$2 oracleCompMap_inner _ _ oracleComp (O, (k, v)) inputs;
  (* the "_" here has type (nat * KV) *)
  $ ret bits.

(* Generate_v_PRF_oc state n: a "shifted" variant of Generate.
   It makes no oracle queries and computes with f directly; it is lifted
   into OracleComp only so that Oi_oc' can choose it alongside the oracle
   variants.
   Unlike Generate, the value refresh v' := f k v happens at the START of
   the call rather than at the end.  The returned state carries the
   un-refreshed final loop value v''.  The new key is f k (v'' ++ zeroes),
   exactly as in Generate. *)
Definition Generate_v_PRF_oc (state : KV) (n : nat) :
  OracleComp (list bool) (Bvector eta) (list (Bvector eta) * KV) :=
  [k, v] <-2 state;
  v' <- f k (to_list v);
  [bits, v''] <-2 Gen_loop k v' n;
  k' <- f k (to_list v'' ++ zeroes);
  $ ret (bits, (k', v'')).

(* Gen_loop_oc v n: oracle version of Gen_loop.
   In each of the n iterations, the oracle is queried on the current value
   (where Gen_loop computes f k), and the answer becomes both the next
   output block and the new value.  Returns the n answers in order together
   with the final value; for n = O it returns (nil, v) without querying. *)
Fixpoint Gen_loop_oc (v : Bvector eta) (n : nat)
  : OracleComp (list bool) (Bvector eta) (list (Bvector eta) * Bvector eta) :=
  match n with
  | O => $ ret (nil, v)
  | S m =>
    next <--$ (OC_Query _ (to_list v)); (* ORACLE USE *)
    [rest, vFinal] <--$2 Gen_loop_oc next m;
    $ ret (next :: rest, vFinal)
  end.

(* Generate_v_oc state n: oracle version of a (shifted) Generate call.
   Every place that would compute f k queries the oracle instead:
     - once to refresh the value (v := O(v_0)),
     - n times inside Gen_loop_oc,
     - once on (v' ++ zeroes) to derive the new key k'.
   The input key k is bound but never used.
   In Oi_oc' this variant handles call number i when i > O. *)
Definition Generate_v_oc (state : KV) (n : nat) :
  OracleComp (list bool) (Bvector eta) (list (Bvector eta) * KV) :=
  [k, v_0] <-2 state;
  v <--$ (OC_Query _ (to_list v_0)); (* ORACLE USE *)
  [bits, v'] <--$2 Gen_loop_oc v n;
  (* The returned state has type KV: the oracle-derived key k' paired with
     the final loop value v', which is not refreshed here. *)
  k' <--$ (OC_Query _ (to_list v' ++ zeroes)); (* ORACLE USE *)
  $ ret (bits, (k', v')).

(* Generate_noV_oc state n: like Generate_v_oc, but without the initial
   value refresh.  Gen_loop_oc starts directly from the state's v, then the
   oracle is queried on (v' ++ zeroes) to derive the new key k'.
   The input key k is bound but never used.
   In Oi_oc' this variant handles the first call (callsSoFar = O) when the
   hybrid index is O. *)
Definition Generate_noV_oc (state : KV) (n : nat) :
  OracleComp (list bool) (Bvector eta)  (list (Bvector eta) * KV) :=
  [k, v] <-2 state;
  [bits, v'] <--$2 Gen_loop_oc v n;
  (* The returned state has type KV: the oracle-derived key k' paired with
     the final loop value v', which is not refreshed here. *)
  k' <--$ (OC_Query _ (to_list v' ++ zeroes)); (* ORACLE USE *)
  $ ret (bits, (k', v')).

(* Gen_loop_rb_intermediate k v n: "random bits" replacement for Gen_loop.
   Each of the n blocks is sampled freshly from {0,1}^eta, and the last
   sample is returned as the final value.  v is returned only when n = O.
   k is threaded through the recursion but never otherwise used. *)
Fixpoint Gen_loop_rb_intermediate (k : Bvector eta) (v : Bvector eta) (n : nat)
  : Comp (list (Bvector eta) * Bvector eta) :=
  match n with
  | O => ret (nil, v)
  | S m =>
    sample <-$ {0,1}^eta;
    [rest, vFinal] <-$2 Gen_loop_rb_intermediate k sample m;
    ret (sample :: rest, vFinal)
  end.

(* Generate_rb_intermediate_oc state n: the random-bits call that Oi_oc'
   uses for calls made before the hybrid index (callsSoFar < i).
   A fresh random value v' replaces the value refresh, n random blocks come
   from Gen_loop_rb_intermediate, and the key k is returned unchanged.
   No oracle queries are made; the Comp is lifted into OracleComp with $. *)
Definition Generate_rb_intermediate_oc (state : KV) (n : nat) 
  : OracleComp (list bool) (Bvector eta) (list (Bvector eta) * KV) :=
  [k, v] <-2 state;
  v' <--$ $ {0,1}^eta;
  [bits, v''] <--$2 $ Gen_loop_rb_intermediate k v' n;    (* promote comp to oraclecomp, then remove from o.c. *)
  $ ret (bits, (k, v'')).

(* Oi_oc' i (callsSoFar, state) n: one Generate call of the i-th hybrid.
   Based on the number of calls made so far (callsSoFar) and the hybrid
   index i, it runs:
     callsSoFar < i             -> Generate_rb_intermediate_oc (random blocks)
     callsSoFar = O (so i = O)  -> Generate_noV_oc             (oracle)
     callsSoFar = i (so i > O)  -> Generate_v_oc               (oracle)
     otherwise (callsSoFar > i) -> Generate_v_PRF_oc           (f directly)
   It returns the output blocks and the incremented counter paired with
   the new state. *)
Definition Oi_oc' (i : nat) (sn : nat * KV) (n : nat) 
  : OracleComp Blist (Bvector eta) (list (Bvector eta) * (nat * KV)) :=
  [callsSoFar, state] <-2 sn;
  let Generate_v_choose :=
      (* this behavior (applied with f_oracle) needs to match that of choose_Generate's *)
      if lt_dec callsSoFar i (* callsSoFar < i (override all else) *)
           then Generate_rb_intermediate_oc (* this implicitly has no v to update *)
      else if beq_nat callsSoFar O (* use oracle on 1st call w/o updating v *)
           then Generate_noV_oc 
      else if beq_nat callsSoFar i (* callsSoFar = i *)
           then Generate_v_oc    (* uses provided oracle (PRF or RF) *)
      else Generate_v_PRF_oc in        (* uses PRF with (k,v) updating *)
  [bits, state'] <--$2 Generate_v_choose state n;
  $ ret (bits, (S callsSoFar, state')).

(* PRF_Adversary i: the PRF distinguisher built from the PRG adversary A.
   It runs the i-th hybrid (Oi_oc' i) over every request in requestList,
   querying its own oracle wherever that hybrid does.  The collected
   per-call outputs are then passed to A, whose bit is returned. *)
Definition PRF_Adversary (i : nat) : OracleComp Blist (Bvector eta) bool :=
  outputs <--$ oracleCompMap_outer _ _ (Oi_oc' i) requestList;
  $ A outputs.
(* End constructed adversary definition *)

(* Collision term of the per-hybrid bound: (blocksPerCall + 1)^2 / 2^eta.
   NOTE(review): presumably this bounds a collision among the
   S blocksPerCall oracle inputs of one call -- confirm against the
   imported proof. *)
Definition Pr_collisions := (S blocksPerCall)^2 / 2^eta.
(* PRF advantage of the i-th constructed adversary against f keyed by RndK,
   with {0,1}^eta as the random-function range. *)
Definition PRF_Advantage_Game i : Rat := 
  PRF_Advantage RndK ({0,1}^eta) f _ _ (PRF_Adversary i).
(* argMax score n: an index in 0..n at which score is maximal.
   Ties go to the smallest such index, because a later index is only
   chosen when its score is strictly larger.
   The function argument is called score (not f) so that it does not
   shadow the section variable f, which models HMAC. *)
Fixpoint argMax (score : nat -> Rat) (n : nat) : nat :=
  match n with
  | O => O
  | S n' =>
    let best := argMax score n' in
    if le_Rat_dec (score (S n')) (score best) then best else S n'
  end.

(* The largest PRF advantage of a single constructed adversary, taken
   over hybrid indices 0..numCalls (located with argMax). *)
Definition PRF_Advantage_Max := PRF_Advantage_Game (argMax PRF_Advantage_Game numCalls).
(* Per-hybrid bound: the maximal PRF advantage plus the collision term.
   Presumably bounds the distance between consecutive hybrids G_i and
   G_{i+1} (per the name) -- confirm against the imported proof. *)
Definition Gi_Gi_plus_1_bound := PRF_Advantage_Max + Pr_collisions.

(* The desired security property: HMAC_DRBG is a non-adaptive PRG.
   The non-adaptive PRG advantage of A, using the fixed request schedule
   requestList, is at most numCalls times the per-hybrid bound.
   (numCalls / 1 expresses numCalls as a rational.) *)
Definition HMAC_DRBG_PRG_NA :=
  PRG_Nonadaptive_Advantage _ _ ({0,1}^eta) Instantiate Generate 
  (ret requestList) A <= (numCalls / 1) * Gi_Gi_plus_1_bound.

(* The proof of this property is imported and applied below *)
Require Import hmacdrbg.HMAC_DRBG_nonadaptive.

(* Generate_ideal, instantiated with {0,1}^eta blocks, and Generate_rb
   (both from the imported development) have the same output distribution,
   pointwise for every b and outcome x.  Proof by induction on b. *)
Theorem Generate_rb_eq_ideal :
  forall b x,
  evalDist (Generate_ideal (Bvector_EqDec eta) ({ 0 , 1 }^ eta) b) x ==
  evalDist (Generate_rb eta b) x.
Proof.
  unfold Generate_rb, Generate_ideal in *.
  (* make evalDist opaque so that simpl below leaves it folded *)
  Local Opaque evalDist.
  induction b; intuition; simpl in *.
  - fcf_simp.
    reflexivity.

  - fcf_inline_first.
    fcf_skip.
    fcf_inline_first.
    fcf_skip.
    rewrite IHb.
    rewrite evalDist_right_ident.
    reflexivity.
    rewrite evalDist_right_ident.
    reflexivity.
Qed.

(* The PRG advantage of A against this model equals the distance between
   the imported games G_real and G_ideal.  The real side is matched
   directly; the ideal side is reduced to Generate_rb_eq_ideal, applied to
   each request via compMap_eq. *)
Theorem PRG_Advantage_eq :
  PRG_Nonadaptive_Advantage _ _ ({0,1}^eta) HMAC_DRBG_PRG_NA.Instantiate HMAC_DRBG_PRG_NA.Generate 
  (ret HMAC_DRBG_PRG_NA.requestList) A == 
  | Pr [G_real f _ A blocksPerCall numCalls ] -
       Pr [G_ideal A blocksPerCall numCalls ] |.
Proof.
  unfold PRG_Nonadaptive_Advantage.
  apply ratDistance_eqRat_compat.
  - unfold PRG_G1, G_real.
    fcf_skip.
    reflexivity.
    fcf_simp.
    reflexivity.

  - unfold PRG_G2, G_ideal.
    fcf_simp.
    fcf_skip.
    eapply compMap_eq; intuition.
    apply list_pred_eq.
    subst.
    apply Generate_rb_eq_ideal.
Qed.

(* Main result: the security property HMAC_DRBG_PRG_NA holds.  Rewrite the
   advantage as the G_real/G_ideal distance (PRG_Advantage_eq), then apply
   the imported hybrid bound G1_G2_close. *)
Theorem HMAC_DRBG_PRG_NA_true : HMAC_DRBG_PRG_NA.
Proof.
  unfold HMAC_DRBG_PRG_NA, PRG_Nonadaptive_Advantage.
  eapply leRat_trans.
  apply eqRat_impl_leRat.
  apply PRG_Advantage_eq.
  apply G1_G2_close; intuition.
Qed.
